{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "How can we sample *some* of the variables in a pyll configuration space, while assigning values to the others?\n",
    "\n",
    "Let's look at a simple example involving 2 variables 'a' and 'b'.\n",
    "The 'a' variable controls whether our space returns -1 or some random number, 'b'.\n",
    "\n",
    "If we just run optimization normally, then we'll find that 'a' should be 0 (the index of the choice that gives the lowest\n",
    "return value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:00<00:00, 1413.57trial/s, best loss: -1.0]\n",
      "{'a': 0}\n"
     ]
    }
   ],
   "source": [
    "from hyperopt import hp, fmin, rand\n",
    "space = hp.choice('a', [-1, hp.uniform('b', 0, 1)])\n",
    "best = fmin(fn=lambda x: x, space=space, algo=rand.suggest, max_evals=100)\n",
    "print(best)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "But what if someone else already set up the space, and we just run the search over the other part of the space, which corresponds to the uniform draw?\n",
    "\n",
    "The easiest way to do this is probably to *clone* the search space, while making some substitutions while we're at it.\n",
    "We can just make a new search space in which 'a' is no longer a hyperparameter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 switch\n",
      "1   hyperopt_param\n",
      "2     Literal{a}\n",
      "3     randint\n",
      "4       Literal{2}\n",
      "5   Literal{-1}\n",
      "6   float\n",
      "7     hyperopt_param\n",
      "8       Literal{b}\n",
      "9       uniform\n",
      "10         Literal{0}\n",
      "11         Literal{1}\n"
     ]
    }
   ],
   "source": [
    "# put the configuration space in a local var\n",
    "# so that we can work on it.\n",
    "print(space)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The transformation we want to make on the search space is to replace the `randint` with a constant value of 1, \n",
    "corresponding to always choosing hyperparameter a to be the second element of the list of choices.\n",
    "\n",
    "Now, if you don't have access to the code that generated a search space, then you'll have to go digging around for the\n",
    "node you need to replace. There are two approaches you can use to do this: navigation and search."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "by navigation:\n",
      "0 randint\n",
      "1   Literal{2}\n",
      "by search:\n",
      "0 randint\n",
      "1   Literal{2}\n"
     ]
    }
   ],
   "source": [
    "from hyperopt import pyll\n",
    "\n",
    "# The \"navigation\" approach to finding an internal\n",
    "# search space node:\n",
    "randint_node_nav = space.pos_args[0].pos_args[1]\n",
    "print(\"by navigation:\")\n",
    "print(randint_node_nav)\n",
    "\n",
    "# The \"search\" approach to finding an internal\n",
    "# search space node:\n",
    "randint_nodes = [node for node in pyll.dfs(space) if node.name == 'randint']\n",
    "randint_node_srch, = randint_nodes\n",
    "print(\"by search:\")\n",
    "print(randint_node_srch)\n",
    "\n",
    "assert randint_node_nav == randint_node_srch"
   ]
  },
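  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Searching by op name works here only because the space contains a single `randint`. A minimal sketch of an alternative is to match on the hyperparameter label instead. The helper name `find_param_node` is hypothetical (it is not part of hyperopt), and the sketch assumes what the printed graph above suggests: a `hyperopt_param` node holds its label `Literal` as its first positional argument and the wrapped distribution as its second.\n",
    "\n",
    "```python\n",
    "from hyperopt import hp, pyll\n",
    "\n",
    "# Same space as in the first code cell of this notebook.\n",
    "space = hp.choice('a', [-1, hp.uniform('b', 0, 1)])\n",
    "\n",
    "\n",
    "def find_param_node(expr, label):\n",
    "    # Hypothetical helper: return the single hyperopt_param node\n",
    "    # in `expr` whose label Literal equals `label`.\n",
    "    matches = [\n",
    "        node for node in pyll.dfs(expr)\n",
    "        if node.name == 'hyperopt_param' and node.pos_args[0].obj == label\n",
    "    ]\n",
    "    (match,) = matches  # raises ValueError unless exactly one node matches\n",
    "    return match\n",
    "\n",
    "\n",
    "param_a = find_param_node(space, 'a')\n",
    "# The distribution wrapped by the hyperparameter is its second argument.\n",
    "print(param_a.pos_args[1].name)\n",
    "```\n",
    "\n",
    "Matching on the label should keep working when a space contains several `hp.choice` nodes, where filtering on `node.name == 'randint'` alone would return more than one candidate."
   ]
  },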
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 switch\n",
      "1   hyperopt_param\n",
      "2     Literal{a}\n",
      "3     Literal{1}\n",
      "4   Literal{-1}\n",
      "5   float\n",
      "6     hyperopt_param\n",
      "7       Literal{b}\n",
      "8       uniform\n",
      "9         Literal{0}\n",
      "10         Literal{1}\n"
     ]
    }
   ],
   "source": [
    "space_with_fixed_a = pyll.clone(space, memo={randint_node_nav: pyll.as_apply(1)})\n",
    "print(space_with_fixed_a)"
   ]
  },
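  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before handing the cloned space to `fmin`, it can be worth checking that it behaves as intended. The sketch below rebuilds the same space and clone so that it runs on its own, then draws a few samples with `hyperopt.pyll.stochastic.sample`. If the substitution did what we wanted, every draw should come from the uniform branch and none should be -1. The sample count of 20 is an arbitrary choice.\n",
    "\n",
    "```python\n",
    "import hyperopt.pyll.stochastic\n",
    "from hyperopt import hp, pyll\n",
    "\n",
    "# Rebuild the space and the clone from the cells above.\n",
    "space = hp.choice('a', [-1, hp.uniform('b', 0, 1)])\n",
    "randint_node = space.pos_args[0].pos_args[1]\n",
    "space_with_fixed_a = pyll.clone(space, memo={randint_node: pyll.as_apply(1)})\n",
    "\n",
    "# Draw samples directly from the expression graph, without running fmin.\n",
    "samples = [\n",
    "    hyperopt.pyll.stochastic.sample(space_with_fixed_a) for _ in range(20)\n",
    "]\n",
    "print(min(samples), max(samples))\n",
    "```\n",
    "\n",
    "Sampling the original `space` the same way should return a mix of -1 and uniform draws, which makes the effect of the substitution easy to see."
   ]
  },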
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, having cloned the space with a new term for the randint, we can search the new space.  I wasn't sure if this would work because I haven't really tested the use of hyperopt_params that wrap around non-random nodes (here we replaced the randint with a constant) but it works for random search:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:00<00:00, 1149.65trial/s, best loss: 0.0020179230078352095]\n",
      "{'a': 1, 'b': 0.0020179230078352095}\n"
     ]
    }
   ],
   "source": [
    "best = fmin(fn=lambda x: x, space=space_with_fixed_a, algo=rand.suggest, max_evals=100)\n",
    "print(best)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Yep, sure enough: The TPE implementation is broken by a hyperparameter that turns out to be a constant. At implementation time, that was not part of the plan."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/100 [00:00<?, ?trial/s, best loss=?]\n"
     ]
    },
    {
     "ename": "KeyError",
     "evalue": "'asarray'",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mKeyError\u001B[0m                                  Traceback (most recent call last)",
      "\u001B[0;32m<ipython-input-6-8b19efbb73ec>\u001B[0m in \u001B[0;36m<module>\u001B[0;34m\u001B[0m\n\u001B[1;32m      1\u001B[0m \u001B[0;32mfrom\u001B[0m \u001B[0mhyperopt\u001B[0m \u001B[0;32mimport\u001B[0m \u001B[0mtpe\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m----> 2\u001B[0;31m \u001B[0mbest\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mfmin\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mfn\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;32mlambda\u001B[0m \u001B[0mx\u001B[0m\u001B[0;34m:\u001B[0m \u001B[0mx\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mspace\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0mspace_with_fixed_a\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0malgo\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0mtpe\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0msuggest\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mmax_evals\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;36m100\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m      3\u001B[0m \u001B[0mprint\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mbest\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
      "\u001B[0;32m~/anaconda3/lib/python3.7/site-packages/hyperopt/fmin.py\u001B[0m in \u001B[0;36mfmin\u001B[0;34m(fn, space, algo, max_evals, timeout, loss_threshold, trials, rstate, allow_trials_fmin, pass_expr_memo_ctrl, catch_eval_exceptions, verbose, return_argmin, points_to_evaluate, max_queue_len, show_progressbar)\u001B[0m\n\u001B[1;32m    507\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    508\u001B[0m     \u001B[0;31m# next line is where the fmin is actually executed\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 509\u001B[0;31m     \u001B[0mrval\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mexhaust\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m    510\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    511\u001B[0m     \u001B[0;32mif\u001B[0m \u001B[0mreturn_argmin\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
      "\u001B[0;32m~/anaconda3/lib/python3.7/site-packages/hyperopt/fmin.py\u001B[0m in \u001B[0;36mexhaust\u001B[0;34m(self)\u001B[0m\n\u001B[1;32m    328\u001B[0m     \u001B[0;32mdef\u001B[0m \u001B[0mexhaust\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    329\u001B[0m         \u001B[0mn_done\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mlen\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mtrials\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 330\u001B[0;31m         \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mrun\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mmax_evals\u001B[0m \u001B[0;34m-\u001B[0m \u001B[0mn_done\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mblock_until_done\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0masynchronous\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m    331\u001B[0m         \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mtrials\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mrefresh\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    332\u001B[0m         \u001B[0;32mreturn\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
      "\u001B[0;32m~/anaconda3/lib/python3.7/site-packages/hyperopt/fmin.py\u001B[0m in \u001B[0;36mrun\u001B[0;34m(self, N, block_until_done)\u001B[0m\n\u001B[1;32m    264\u001B[0m                     \u001B[0;31m# processes orchestration\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    265\u001B[0m                     new_trials = algo(\n\u001B[0;32m--> 266\u001B[0;31m                         \u001B[0mnew_ids\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mdomain\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mtrials\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mrstate\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mrandint\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;36m2\u001B[0m \u001B[0;34m**\u001B[0m \u001B[0;36m31\u001B[0m \u001B[0;34m-\u001B[0m \u001B[0;36m1\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m    267\u001B[0m                     )\n\u001B[1;32m    268\u001B[0m                     \u001B[0;32massert\u001B[0m \u001B[0mlen\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mnew_ids\u001B[0m\u001B[0;34m)\u001B[0m \u001B[0;34m>=\u001B[0m \u001B[0mlen\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mnew_trials\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
      "\u001B[0;32m~/anaconda3/lib/python3.7/site-packages/hyperopt/tpe.py\u001B[0m in \u001B[0;36msuggest\u001B[0;34m(new_ids, domain, trials, seed, prior_weight, n_startup_jobs, n_EI_candidates, gamma, verbose)\u001B[0m\n\u001B[1;32m    866\u001B[0m     \u001B[0;31m# use build_posterior_wrapper to create the pyll nodes\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    867\u001B[0m     observed, observed_loss, posterior = build_posterior_wrapper(\n\u001B[0;32m--> 868\u001B[0;31m         \u001B[0mdomain\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mprior_weight\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mgamma\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m    869\u001B[0m     )\n\u001B[1;32m    870\u001B[0m     \u001B[0mtt\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mtime\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mtime\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m \u001B[0;34m-\u001B[0m \u001B[0mt0\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
      "\u001B[0;32m~/anaconda3/lib/python3.7/site-packages/hyperopt/tpe.py\u001B[0m in \u001B[0;36mbuild_posterior_wrapper\u001B[0;34m(domain, prior_weight, gamma)\u001B[0m\n\u001B[1;32m    830\u001B[0m         \u001B[0mobserved_loss\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;34m\"vals\"\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m,\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    831\u001B[0m         \u001B[0mpyll\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mLiteral\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mgamma\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m,\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 832\u001B[0;31m         \u001B[0mpyll\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mLiteral\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mfloat\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mprior_weight\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m,\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m    833\u001B[0m     )\n\u001B[1;32m    834\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n",
      "\u001B[0;32m~/anaconda3/lib/python3.7/site-packages/hyperopt/tpe.py\u001B[0m in \u001B[0;36mbuild_posterior\u001B[0;34m(specs, prior_idxs, prior_vals, obs_idxs, obs_vals, obs_loss_idxs, obs_loss_vals, oloss_gamma, prior_weight)\u001B[0m\n\u001B[1;32m    708\u001B[0m                 \u001B[0mobs_below\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mobs_above\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mobs_memo\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0mnode\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    709\u001B[0m                 \u001B[0maa\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0;34m[\u001B[0m\u001B[0mmemo\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0ma\u001B[0m\u001B[0;34m]\u001B[0m \u001B[0;32mfor\u001B[0m \u001B[0ma\u001B[0m \u001B[0;32min\u001B[0m \u001B[0mnode\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mpos_args\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 710\u001B[0;31m                 \u001B[0mfn\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0madaptive_parzen_samplers\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0mnode\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mname\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m    711\u001B[0m                 \u001B[0mb_args\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0;34m[\u001B[0m\u001B[0mobs_below\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mprior_weight\u001B[0m\u001B[0;34m]\u001B[0m \u001B[0;34m+\u001B[0m \u001B[0maa\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    712\u001B[0m                 \u001B[0mnamed_args\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0;34m{\u001B[0m\u001B[0mkw\u001B[0m\u001B[0;34m:\u001B[0m \u001B[0mmemo\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0marg\u001B[0m\u001B[0;34m]\u001B[0m \u001B[0;32mfor\u001B[0m \u001B[0;34m(\u001B[0m\u001B[0mkw\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0marg\u001B[0m\u001B[0;34m)\u001B[0m \u001B[0;32min\u001B[0m 
\u001B[0mnode\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mnamed_args\u001B[0m\u001B[0;34m}\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
      "\u001B[0;31mKeyError\u001B[0m: 'asarray'"
     ]
    }
   ],
   "source": [
    "from hyperopt import tpe\n",
    "best = fmin(fn=lambda x: x, space=space_with_fixed_a, algo=tpe.suggest, max_evals=100)\n",
    "print(best)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The TPE algorithm works if we make a different replacement in the graph. If we replace the entire \"hyperopt_param\" node corresponding to hyperparameter \"a\", then it works fine."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:00<00:00, 352.69trial/s, best loss: 0.00037832452435874396]\n",
      "{'b': 0.00037832452435874396}\n"
     ]
    }
   ],
   "source": [
    "space_with_no_a = pyll.clone(space, memo={space.pos_args[0]: pyll.as_apply(1)})\n",
    "best = fmin(fn=lambda x: x, space=space_with_no_a, algo=tpe.suggest, max_evals=100)\n",
    "print(best)"
   ]
  }
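  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The whole-node replacement can be wrapped in a small helper so that a choice is fixed by label rather than by position in the graph. This is only a sketch: `fix_choice` is a hypothetical name (not part of hyperopt), it assumes that a `hyperopt_param` node stores its label `Literal` as its first positional argument, and `max_evals=20` is an arbitrary budget.\n",
    "\n",
    "```python\n",
    "from hyperopt import hp, fmin, tpe, pyll\n",
    "\n",
    "\n",
    "def fix_choice(expr, label, index):\n",
    "    # Hypothetical helper: clone `expr`, replacing the whole hyperopt_param\n",
    "    # node named `label` with the constant `index`, so that the label\n",
    "    # is no longer a hyperparameter of the returned space.\n",
    "    targets = [\n",
    "        node for node in pyll.dfs(expr)\n",
    "        if node.name == 'hyperopt_param' and node.pos_args[0].obj == label\n",
    "    ]\n",
    "    (target,) = targets  # raises ValueError unless exactly one node matches\n",
    "    return pyll.clone(expr, memo={target: pyll.as_apply(index)})\n",
    "\n",
    "\n",
    "space = hp.choice('a', [-1, hp.uniform('b', 0, 1)])\n",
    "space_b_only = fix_choice(space, 'a', 1)\n",
    "best = fmin(fn=lambda x: x, space=space_b_only, algo=tpe.suggest,\n",
    "            max_evals=20, show_progressbar=False)\n",
    "print(best)\n",
    "```\n",
    "\n",
    "Because the entire `hyperopt_param` node is replaced, the result should contain only 'b', as in the cell above."
   ]
  }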
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
